{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "948fd260",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "# Layers and Modules\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3b911ed7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:38.169811Z",
     "iopub.status.busy": "2023-08-18T19:31:38.169219Z",
     "iopub.status.idle": "2023-08-18T19:31:40.246403Z",
     "shell.execute_reply": "2023-08-18T19:31:40.245375Z"
    },
    "origin_pos": 3,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdc508a8",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "To begin, we revisit the code\n",
    "that we used to implement MLPs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7df830d6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:40.251527Z",
     "iopub.status.busy": "2023-08-18T19:31:40.250671Z",
     "iopub.status.idle": "2023-08-18T19:31:40.284734Z",
     "shell.execute_reply": "2023-08-18T19:31:40.283757Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 10])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))\n",
    "\n",
    "X = torch.rand(2, 20)\n",
    "net(X).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3cf474b",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "`nn.Sequential` defines a special kind of `Module`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41e206ef",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "A Custom Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9c5f010e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:40.289115Z",
     "iopub.status.busy": "2023-08-18T19:31:40.288828Z",
     "iopub.status.idle": "2023-08-18T19:31:40.295756Z",
     "shell.execute_reply": "2023-08-18T19:31:40.294461Z"
    },
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden = nn.LazyLinear(256)\n",
    "        self.out = nn.LazyLinear(10)\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.out(F.relu(self.hidden(X)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "652c0c95",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Instantiate the MLP's layers\n",
    "and subsequently invoke these layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8c8301dc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:40.300597Z",
     "iopub.status.busy": "2023-08-18T19:31:40.300120Z",
     "iopub.status.idle": "2023-08-18T19:31:40.308051Z",
     "shell.execute_reply": "2023-08-18T19:31:40.307090Z"
    },
    "origin_pos": 20,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 10])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MLP()\n",
    "net(X).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7220684",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The Sequential Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2b8d5d6a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:40.323023Z",
     "iopub.status.busy": "2023-08-18T19:31:40.322454Z",
     "iopub.status.idle": "2023-08-18T19:31:40.330187Z",
     "shell.execute_reply": "2023-08-18T19:31:40.329340Z"
    },
    "origin_pos": 31,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 10])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MySequential(nn.Module):\n",
    "    def __init__(self, *args):\n",
    "        super().__init__()\n",
    "        for idx, module in enumerate(args):\n",
    "            self.add_module(str(idx), module)\n",
    "\n",
    "    def forward(self, X):\n",
    "        for module in self.children():\n",
    "            X = module(X)\n",
    "        return X\n",
    "\n",
    "net = MySequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))\n",
    "net(X).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7adc85f",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Executing Code in the Forward Propagation Method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ede5347f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:40.344398Z",
     "iopub.status.busy": "2023-08-18T19:31:40.343674Z",
     "iopub.status.idle": "2023-08-18T19:31:40.355810Z",
     "shell.execute_reply": "2023-08-18T19:31:40.353856Z"
    },
    "origin_pos": 40,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-0.3836, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class FixedHiddenMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.rand_weight = torch.rand((20, 20))\n",
    "        self.linear = nn.LazyLinear(20)\n",
    "\n",
    "    def forward(self, X):\n",
    "        X = self.linear(X)\n",
    "        X = F.relu(X @ self.rand_weight + 1)\n",
    "        X = self.linear(X)\n",
    "        while X.abs().sum() > 1:\n",
    "            X /= 2\n",
    "        return X.sum()\n",
    "\n",
    "net = FixedHiddenMLP()\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "935d499c",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Mix and match various\n",
    "ways of assembling modules together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c0c3d190",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:40.359588Z",
     "iopub.status.busy": "2023-08-18T19:31:40.359258Z",
     "iopub.status.idle": "2023-08-18T19:31:40.372492Z",
     "shell.execute_reply": "2023-08-18T19:31:40.371497Z"
    },
    "origin_pos": 44,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0679, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(nn.LazyLinear(64), nn.ReLU(),\n",
    "                                 nn.LazyLinear(32), nn.ReLU())\n",
    "        self.linear = nn.LazyLinear(16)\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.linear(self.net(X))\n",
    "\n",
    "chimera = nn.Sequential(NestMLP(), nn.LazyLinear(20), FixedHiddenMLP())\n",
    "chimera(X)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "language_info": {
   "name": "python"
  },
  "required_libs": [],
  "rise": {
   "autolaunch": true,
   "enable_chalkboard": true,
   "overlay": "<div class='my-top-right'><img height=80px src='http://d2l.ai/_static/logo-with-text.png'/></div><div class='my-top-left'></div>",
   "scroll": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}